Conversion from FP32 model to Mixed Precision model #15118
@ptrendx @pengzhao-intel @samskalicky @larroy @ZhennanQin Thank you for your review! I have addressed your comments.
I think the approach and architecture look good; in my view, only a few C++ formalities remain to be refined.
src/c_api/c_api_symbolic.cc
@@ -810,6 +810,191 @@ int MXQuantizeSymbol(SymbolHandle sym_handle,
API_END_HANDLE_ERROR(delete s);
}

// helper function to add mapping of node_name -> dtype map
// for the given indexed graph and inferred_dtypes
inline void _SetInputDTypes(
Please use an anonymous namespace or a static function for helpers, not inline.
const std::unordered_map<std::string, int>& node_name_dtype_map,
const std::unordered_map<std::string, int>& node_without_dtype_map,
const std::unordered_set<std::string>& model_params,
const std::vector<nnvm::NodePtr>& args) {
args is const yet its nodes are modified through NodePtr, which is misleading; a short documentation note would help.
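To illustrate the point about const and NodePtr, here is a minimal sketch. `ToyNode`, `ToyNodePtr`, and `SetAllDTypes` are hypothetical stand-ins invented for this example, not the real nnvm types or the PR's code; the sketch only assumes that the node pointer behaves like a `std::shared_ptr`. It shows that a const reference to a vector of shared pointers still lets the callee mutate the pointed-to nodes, which is why a comment on the parameter is worth having.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for a graph node; not the real nnvm::Node.
struct ToyNode {
  std::string name;
  int dtype = -1;  // -1 means "not yet set"
};
using ToyNodePtr = std::shared_ptr<ToyNode>;

// `args` is a const reference, so the vector itself cannot be resized or
// have its elements reseated. Each element still points to a non-const
// ToyNode, so the nodes themselves can be changed.
// NOTE: this function mutates the nodes that `args` points to.
void SetAllDTypes(const std::vector<ToyNodePtr>& args, int dtype) {
  for (const ToyNodePtr& node : args) {
    assert(node != nullptr);
    node->dtype = dtype;  // compiles: constness does not propagate through shared_ptr
  }
}
```

A one-line note such as the `NOTE:` comment above is usually enough to tell callers that the const qualifier protects the container, not the nodes.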
src/nnvm/low_precision_pass.cc
}

// add amp_cast node between curr_node and input
void AddCastNode(const nnvm::NodeEntry &e, const std::string &suffix,
Use static or an anonymous namespace to avoid additional unnecessary linker symbols.
src/nnvm/low_precision_pass.cc
}

// get suffix for a node entry so that it can be used for amp_cast/amp_multicast node name
std::string GetSuffix(const nnvm::NodeEntry &node_entry,
Same comment about an anonymous namespace or static; I guess it applies to other places as well.
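The suggestion to use an anonymous namespace or static can be sketched as follows. The helper name `MakeSuffix`, the wrapper `CastNodeName`, and the string format are hypothetical and chosen only for illustration; they are not the PR's `GetSuffix` or `AddCastNode`. The point is that a helper defined inside an unnamed namespace has internal linkage, so it is not exported as a linker symbol and cannot collide with a same-named function in another translation unit.

```cpp
#include <cassert>
#include <string>

namespace {  // everything declared here has internal linkage

// Hypothetical file-local helper; the name and output format are illustrative.
std::string MakeSuffix(const std::string& node_name, int output_index) {
  assert(output_index >= 0);
  return node_name + "_" + std::to_string(output_index);
}

}  // namespace

// Function with external linkage that uses the file-local helper.
std::string CastNodeName(const std::string& node_name, int output_index) {
  return MakeSuffix(node_name, output_index) + "_amp_cast";
}
```

Marking the helper `static` gives the same internal linkage for a single function; the unnamed namespace is convenient when several helpers and types in one file should stay private.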
std::unordered_set<std::string> widest_dtype_ops;
std::unordered_set<std::string> excluded_syms;
std::unordered_set<std::string> model_params;
std::unordered_map<std::string,
I suggest adding a comment on what this represents, or maybe a typedef with a doc comment. It would help with code maintainability.
Nice pr.
From an offline review done by sudipta@, feedback was provided that it is important for users to be able to obtain models with params cast wherever possible. After additional discussion with @ptrendx, we decided to add an additional graph pass, which goes through all inputs of amp_cast and amp_multicast to infer the dtypes of the input nodes wherever possible. I have added support for this in the recent commits.
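The idea of such a pass can be sketched on a toy graph. Everything here is an assumption made for illustration: `ToyNode`, `kConflict`, and `InferInputDTypes` are hypothetical names, the dtype values are arbitrary integer flags, only amp_cast is handled (amp_multicast is left out), and the conflict rule (skip a variable that feeds casts with different target dtypes) is one possible reading of "wherever possible", not the PR's confirmed behavior.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical toy node; not the real nnvm::Node.
struct ToyNode {
  std::string name;
  std::string op;         // empty string marks a variable (input or param)
  int target_dtype = -1;  // only meaningful when op == "amp_cast"
  std::vector<std::shared_ptr<ToyNode>> inputs;
};

constexpr int kConflict = -2;  // marker for variables with disagreeing casts

// For every amp_cast node, record its target dtype for each variable input.
// Variables that feed casts with different target dtypes are dropped from
// the result, since no single dtype can be inferred for them.
std::unordered_map<std::string, int> InferInputDTypes(
    const std::vector<std::shared_ptr<ToyNode>>& nodes) {
  std::unordered_map<std::string, int> seen;
  for (const auto& node : nodes) {
    assert(node != nullptr);
    if (node->op != "amp_cast") continue;
    for (const auto& in : node->inputs) {
      if (!in->op.empty()) continue;  // only variables can be cast offline
      auto it = seen.find(in->name);
      if (it == seen.end()) {
        seen.emplace(in->name, node->target_dtype);
      } else if (it->second != node->target_dtype) {
        it->second = kConflict;
      }
    }
  }
  std::unordered_map<std::string, int> result;
  for (const auto& kv : seen) {
    if (kv.second != kConflict) result.insert(kv);
  }
  return result;
}
```

In this sketch, the returned map names the variables whose stored values could be cast once ahead of time, which is the user-facing benefit described in the comment above.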
Description
Users want to bring an FP32 model and convert it to a mixed precision model to run inference on it. This work leverages the existing AMP work by @ptrendx and @Caenorst and provides conversion APIs that turn a symbolic model or Gluon model into a mixed precision model. It also adds the necessary C APIs, so that similar conversion support can be added in other frontends.
Thanks to all involved in prior discussions, suggestions and design review of the project (sincere apologies if I missed someone):
@ptrendx, @rahul003, Sudipta Sengupta (AWS) (@sudiptasengupta), Poorna Chand Srinivas Perumalla (@bhagatindia) (AWS), Wei Xiao (AWS), @Vikas89, @lupesko, @pengzhao-intel , @ZhennanQin
API Additions (Python)
Refactoring or Existing code changes
Module API (executor_group.py)
Gluon API (parameter.py)
Test Utils API (test_utils.py)
AMP Tests
Additions
Fixes : #14584
Doc: https://tinyurl.com/y42kx9hl
Other Flaky Tests/Bug Fixes: